package tree

import (
	"fmt"
	"gitee.com/kklt1996/data-structure/common"
	"strings"
)

/*
	线段树(区间树)
	解决对固定个数的元素,指定区间进行查询问题
	此线段树用于计算[i,j]区间内元素的和
*/
type ArraySegmentTree struct {
	/*
		存放线段树的数组
		因为使用满二叉树存储线段树,满二叉树最后一层元素个数为n的话,满二叉树元素个数为2n-1
		因此有n个元素的线段树,看成满二叉树,最后一层元素个数最多是2n,满二叉树个数为4n-2
		数组的大小等于len(data)的四倍,保证能存储下线段树.
	*/
	tree []interface{}

	/*
		存放源数据的数组
	*/
	data []interface{}

	/*
		聚合器,将左右子节点聚合成根节点
	*/
	aggregator func(leftChildValue interface{}, rightChildValue interface{}) interface{}
}

/*
	获取线段树区间大小
*/
func (tree ArraySegmentTree) GetSize() int {
	return len(tree.data)
}

/*
	O(log(n))
	对区间[l,r]进行查询
*/
func (tree ArraySegmentTree) Query(queryL int, queryR int) (interface{}, error) {
	if queryL < 0 || queryL > len(tree.data) || queryR < 0 || queryR > len(tree.data) || queryR < queryL {
		return nil, common.IndexError{}
	}
	return tree.query(0, 0, len(tree.data)-1, queryL, queryR), nil
}

func (tree ArraySegmentTree) query(treeIndex int, l int, r int, queryL int, queryR int) interface{} {
	// treeIndex表示的区间等于要查询的区间,返回结果
	if l == queryL && r == queryR {
		return tree.tree[treeIndex]
	} else {
		leftChildIndex := tree.leftChildIndex(treeIndex)
		rightChildIndex := tree.rightChildIndex(treeIndex)
		mid := l + (r-l)/2
		if queryL >= mid+1 {
			// 要查询区间在左子树,从左子树查询
			return tree.query(rightChildIndex, mid+1, r, queryL, queryR)
		} else if queryR <= mid {
			// 要查询区间在右子树,从右子树查询
			return tree.query(leftChildIndex, l, mid, queryL, queryR)
		} else {
			// 要查询的区间分布在左子树和右子树,从左右子树聚合结果
			leftSegmentValue := tree.query(leftChildIndex, l, mid, queryL, mid)
			rightSegmentValue := tree.query(rightChildIndex, mid+1, r, mid+1, queryR)
			return tree.aggregator(leftSegmentValue, rightSegmentValue)
		}
	}
}

/*
	O(log(n))
	更新数组中的值
*/
func (tree *ArraySegmentTree) Set(i int, value interface{}) error {
	if i < 0 || i >= len(tree.data) {
		return common.IndexError{}
	}
	tree.data[i] = value
	// 对线段树进行更新操作
	tree.set(0, 0, len(tree.data)-1, i, value)
	return nil
}

/*
	在[l,r]范围内更新index位置元素为value后,更新线段树中范围包含index的节点的值
*/
func (tree *ArraySegmentTree) set(treeIndex int, l int, r int, index int, value interface{}) {
	if l == r {
		tree.tree[treeIndex] = tree.aggregator(value, nil)
		return
	}
	leftChildIndex := tree.leftChildIndex(treeIndex)
	rightChildIndex := tree.rightChildIndex(treeIndex)
	mid := l + (r-l)/2
	if index <= mid {
		// 如果在左子树的范围,就更新左子树
		tree.set(leftChildIndex, l, mid, index, value)
	} else {
		// 如果在右子树的范围,就更新右子树
		tree.set(rightChildIndex, mid+1, r, index, value)
	}
	// 对左右子树进行合并操作
	tree.tree[treeIndex] = tree.aggregator(tree.tree[leftChildIndex], tree.tree[rightChildIndex])
}

/*
	O(n)
	初始化线段树
*/
func (tree *ArraySegmentTree) init(treeIndex int, l int, r int) {
	if l == r {
		tree.tree[treeIndex] = tree.aggregator(tree.data[l], nil)
		return
	} else {
		// 计算左右孩子索引
		leftChildIndex := tree.leftChildIndex(treeIndex)
		rightChildIndex := tree.rightChildIndex(treeIndex)
		// 取两个数字的平均值,鉴于l+r可能会发生溢出的问题,可以使用l + (r-l)/2等价于(r+l)/2
		mid := l + (r-l)/2
		// 计算左右孩子的值
		tree.init(leftChildIndex, l, mid)
		tree.init(rightChildIndex, mid+1, r)
		// 根据左右孩子的结果计算根节点的结果
		tree.tree[treeIndex] = tree.aggregator(tree.tree[leftChildIndex], tree.tree[rightChildIndex])
	}
}

/*
	获取左孩子的索引
*/
func (tree ArraySegmentTree) leftChildIndex(index int) int {
	return 2*index + 1
}

/*
	获取右孩子的索引
*/
func (tree ArraySegmentTree) rightChildIndex(index int) int {
	return 2*index + 2

}

func (tree ArraySegmentTree) String() string {
	res := "["
	for _, v := range tree.tree {
		if v != nil {
			res += fmt.Sprintf("%v,", v)
		} else {
			res += "nil,"
		}
	}
	res = strings.TrimSuffix(res, ",")
	res += "]"
	return res
}
